package UDAF



import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/*
 * Custom aggregate function: computes the average salary.
 * 1. Sum all salaries.
 * 2. Count the number of rows.
 * 3. Average salary = total salary / row count.
 * */
/**
 * Average-salary UDAF: accumulates (total salary, row count) per group and
 * returns total / count.
 *
 * NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0;
 * consider migrating to org.apache.spark.sql.expressions.Aggregator.
 */
object UDAFTest extends UserDefinedAggregateFunction{
  // Input schema: a single nullable Double column named "salary".
  override def inputSchema: StructType = StructType(
    Array(StructField("salary", DoubleType, true))
  )

  // Aggregation buffer schema:
  //   slot 0 -> running salary total (Double)
  //   slot 1 -> running row count    (Int)
  override def bufferSchema: StructType = StructType(
    StructField("totalSalary", DoubleType, true) ::
    StructField("count", IntegerType, true) :: Nil
  )

  // Result type of evaluate(): the average salary.
  override def dataType: DataType = DoubleType

  // Reset both buffer slots before aggregating a new group.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0.0)
    buffer.update(1, 0)
  }

  // Fold one input row into the partial (per-partition) aggregate.
  // Fix: skip null salaries — inputSchema declares the column nullable, and
  // input.getDouble(0) throws on a null value. Null rows are excluded from
  // both the total and the count, matching SQL AVG semantics.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
      buffer.update(1, buffer.getInt(1) + 1)
    }
  }

  // Merge two partial aggregates (buffer2 is folded into buffer1) to produce
  // the global aggregate across partitions.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
    buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1))
  }

  // Final result: total salary / row count.
  // Fix: return null for an empty group (SQL AVG semantics) instead of
  // 0.0 / 0 == NaN, which the original produced.
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getInt(1)
    if (count == 0) null else buffer.getDouble(0) / count
  }

  // True: the same input rows always produce the same result. (The original
  // comment mischaracterized this flag as input/output type consistency.)
  override def deterministic: Boolean = true
}
